Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Scan logprob inference of non-pure RandomVariable outputs #6578

Merged
merged 1 commit into from
Mar 8, 2023

Conversation

ricardoV94
Copy link
Member

Most of the IR logprob rewrites require a PreserveRVMappings feature in the fgraph. The rewrite responsible to introduce IR in the inner graph of Scan was not adding this feature.

In addition, find_measurable_scans, was bailing out when there were MesurableVariable nodes that were not outputs, even if these were being used by downstream nodes as the source of measurability.

After this fix, both of the following work:

import numpy as np
import pytensor

import pymc as pm

grw1, _ = pytensor.scan(
  fn = lambda x: pm.Normal.dist(x),
  outputs_info=[pm.DiracDelta.dist(0.0)],
  n_steps=10,
)

grw2, _ = pytensor.scan(
  fn = lambda x: pm.Normal.dist() + x,
  outputs_info=[pm.DiracDelta.dist(0.0)],
  n_steps=10,
)

print(pm.logp(grw1, np.arange(10)).eval())
print(pm.logp(grw2, np.arange(10)).eval())

CC @jessegrabowski

Most of the IR logprob rewrites require a PreserveRVMappings feature in the fgraph. The rewrite responsible to introduce IR in the inner graph of Scan was not adding this feature.

In addition, find_measurable_scans, was bailing out when there were MesurableVariable nodes that were not outputs, even if these were being used by downstream nodes as the source of measurability.
@codecov
Copy link

codecov bot commented Mar 8, 2023

Codecov Report

Merging #6578 (684707c) into main (e76bba9) will increase coverage by 7.55%.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6578      +/-   ##
==========================================
+ Coverage   84.45%   92.01%   +7.55%     
==========================================
  Files          91       91              
  Lines       15108    15114       +6     
==========================================
+ Hits        12760    13907    +1147     
+ Misses       2348     1207    -1141     
Impacted Files Coverage Δ
pymc/logprob/scan.py 96.92% <100.00%> (+77.34%) ⬆️
pymc/backends/base.py 85.83% <0.00%> (+0.42%) ⬆️
pymc/backends/arviz.py 96.39% <0.00%> (+1.03%) ⬆️
pymc/distributions/distribution.py 96.61% <0.00%> (+1.23%) ⬆️
pymc/variational/minibatch_rv.py 100.00% <0.00%> (+1.72%) ⬆️
pymc/backends/__init__.py 89.18% <0.00%> (+2.70%) ⬆️
pymc/logprob/cumsum.py 100.00% <0.00%> (+3.12%) ⬆️
pymc/step_methods/metropolis.py 86.56% <0.00%> (+3.41%) ⬆️
pymc/variational/updates.py 92.11% <0.00%> (+3.44%) ⬆️
pymc/smc/sampling.py 86.61% <0.00%> (+5.63%) ⬆️
... and 21 more

@ricardoV94 ricardoV94 merged commit 49aacf4 into pymc-devs:main Mar 8, 2023
@ricardoV94 ricardoV94 deleted the fix_scan_infer_logp branch June 6, 2023 03:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants